{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Improving function approximation by adjusting tuning curves\n",
    "\n",
    "This tutorial shows how adjusting the tuning curves of neurons\n",
    "can help implement specific functions with Nengo.\n",
    "As an example, we will try to compute\n",
    "the Heaviside step function,\n",
    "which is 1 for all $x > 0$ and 0 otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import nengo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The naive approach\n",
    "\n",
    "As a first pass, we can try to implement the Heaviside step function\n",
    "using an ensemble with default parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_neurons = 150\n",
    "duration = 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
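    "def stimulus_fn(t):\n",
    "    # Ramp the input linearly from -1 to 1 over the simulation\n",
    "    return (2. * t / duration) - 1.\n",
    "\n",
    "\n",
    "def heaviside(x):\n",
    "    # Target function: 1 (True) for x > 0, otherwise 0 (False)\n",
    "    return x > 0"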
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Network() as model_naive:\n",
    "    stimulus = nengo.Node(stimulus_fn)\n",
    "    ens = nengo.Ensemble(n_neurons=n_neurons, dimensions=1)\n",
    "    output = nengo.Node(size_in=1)\n",
    "\n",
    "    nengo.Connection(stimulus, ens)\n",
    "    nengo.Connection(ens, output, function=heaviside)\n",
    "\n",
    "    p_naive = nengo.Probe(output, synapse=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model_naive) as sim_naive:\n",
    "    sim_naive.run(duration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = sim_naive.trange()\n",
    "plt.figure()\n",
    "plt.plot(t, sim_naive.data[p_naive], label=\"naive\")\n",
    "plt.plot(t, heaviside(stimulus_fn(t)), '--', c='black', label=\"ideal\")\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"Output\")\n",
    "plt.legend(loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that this approach does work,\n",
    "but there is room for improvement."
   ]
  },
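  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on the room for improvement,\n",
    "one option is the root-mean-square error (RMSE)\n",
    "between the decoded output and the ideal step.\n",
    "The next cell is a minimal sketch, not part of the original model:\n",
    "`rmse` is a helper defined here purely for illustration,\n",
    "and the `*_demo` arrays are synthetic stand-ins, not simulation results.\n",
    "With the simulation above, the same helper could be called as\n",
    "`rmse(sim_naive.data[p_naive], heaviside(stimulus_fn(t)))`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def rmse(decoded, ideal):\n",
    "    \"\"\"Root-mean-square error between two equally long signals.\n",
    "\n",
    "    Illustrative helper. Both inputs are flattened so that an (n, 1)\n",
    "    probe array can be compared with an (n,) reference.\n",
    "    \"\"\"\n",
    "    decoded = np.ravel(np.asarray(decoded, dtype=float))\n",
    "    ideal = np.ravel(np.asarray(ideal, dtype=float))\n",
    "    if decoded.shape != ideal.shape:\n",
    "        raise ValueError(\"signals must have the same number of samples\")\n",
    "    return float(np.sqrt(np.mean((decoded - ideal) ** 2)))\n",
    "\n",
    "\n",
    "# Synthetic stand-ins (assumed data, not simulation output):\n",
    "# an ideal step and a constant guess of 0.5 shaped like probe data.\n",
    "x_demo = np.linspace(-1, 1, 200)\n",
    "ideal_demo = (x_demo > 0).astype(float)\n",
    "guess_demo = np.full((200, 1), 0.5)\n",
    "\n",
    "print(\"perfect decode:\", rmse(ideal_demo, ideal_demo))\n",
    "print(\"constant 0.5 guess:\", rmse(guess_demo, ideal_demo))"
   ]
  },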
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Investigating the tuning curves\n",
    "\n",
    "Let us take a look at\n",
    "the tuning curves of the neurons in the ensemble."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(*nengo.utils.ensemble.tuning_curves(ens, sim_naive))\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Firing rate [Hz]\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "About half of these neurons are tuned to fire more for smaller values.\n",
    "But these values are not relevant\n",
    "for the Heaviside step function,\n",
    "since the output is always 0\n",
    "when input is less than 0.\n",
    "We can change all neurons to be tuned\n",
    "to fire more for larger values\n",
    "by setting all the encoders to be positive."
   ]
  },
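  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why the encoder sign matters can be seen in a simplified rate model.\n",
    "The next cell is a sketch under stated assumptions:\n",
    "it uses a rectified-linear response as a stand-in for\n",
    "the LIF tuning curves plotted above,\n",
    "and `toy_rate`, its `gain` value, and the intercept of 0.2\n",
    "are hypothetical choices made for illustration.\n",
    "A neuron with encoder -1 only responds to negative inputs,\n",
    "where the target output is always 0,\n",
    "while a neuron with encoder +1 responds to positive inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def toy_rate(x, encoder, intercept, gain=100.0):\n",
    "    \"\"\"Rectified-linear stand-in for a tuning curve (not Nengo's LIF model).\n",
    "\n",
    "    The neuron stays silent until encoder * x exceeds its intercept.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    return np.maximum(0.0, gain * (encoder * x - intercept))\n",
    "\n",
    "\n",
    "x_toy = np.linspace(-1, 1, 5)  # inputs -1, -0.5, 0, 0.5, 1\n",
    "\n",
    "# Encoder -1: active only for negative inputs (target output is 0 there).\n",
    "rate_neg = toy_rate(x_toy, encoder=-1.0, intercept=0.2)\n",
    "# Encoder +1: active only for positive inputs (target output is 1 there).\n",
    "rate_pos = toy_rate(x_toy, encoder=1.0, intercept=0.2)\n",
    "\n",
    "print(\"encoder -1:\", rate_neg)\n",
    "print(\"encoder +1:\", rate_pos)"
   ]
  },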
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Network() as model_pos_enc:\n",
    "    stimulus = nengo.Node(stimulus_fn)\n",
    "    ens = nengo.Ensemble(n_neurons=n_neurons, dimensions=1,\n",
    "                         encoders=nengo.dists.Choice([[1.]]))\n",
    "    output = nengo.Node(size_in=1)\n",
    "\n",
    "    nengo.Connection(stimulus, ens)\n",
    "    nengo.Connection(ens, output, function=heaviside)\n",
    "\n",
    "    p_pos_enc = nengo.Probe(output, synapse=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model_pos_enc) as sim_pos_enc:\n",
    "    sim_pos_enc.run(duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The resulting tuning curves:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(*nengo.utils.ensemble.tuning_curves(ens, sim_pos_enc))\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Firing rate [Hz]\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And the resulting decoded Heaviside step function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = sim_pos_enc.trange()\n",
    "plt.figure()\n",
    "plt.plot(t, sim_naive.data[p_naive], label=\"naive\")\n",
    "plt.plot(t, sim_pos_enc.data[p_pos_enc], label=\"pos. enc.\")\n",
    "plt.plot(t, heaviside(stimulus_fn(t)), '--', c='black', label=\"ideal\")\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"Output\")\n",
    "plt.legend(loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compared to the naive approach,\n",
    "the representation of outputs lower than 1 is less noisy,\n",
    "but otherwise there is little improvement.\n",
    "Even though the tuning curves are all positive,\n",
    "they are still covering a lot of irrelevant parts of the input space.\n",
    "Because all outputs should be 0 when input is less than 0,\n",
    "there is no need to have neurons tuned to inputs less than 0.\n",
    "Let's shift all the intercepts to the range $(0.0, 1.0)$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intercept distributions\n",
    "\n",
    "Not only can the range of intercepts be important,\n",
    "but also the distribution of intercepts.\n",
    "Let us take a look at the Heaviside step function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.linspace(-1, 1, 100)\n",
    "plt.figure()\n",
    "plt.plot(xs, heaviside(xs))\n",
    "plt.ylim(-0.1, 1.1);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function is mostly constant,\n",
    "except for the large jump at 0.\n",
    "The constant parts are easy to approximate\n",
    "and do not need a lot of neural resources,\n",
    "but the highly nonlinear jump\n",
    "requires more neural resources\n",
    "for an accurate representation.\n",
    "\n",
    "Let us compare the thresholding of a scalar in three ways:\n",
    "\n",
    "1. With a uniform distribution of intercepts (the default shape, here limited to the range from 0 to 1)\n",
    "2. With all intercepts at 0 (where we have the nonlinearity)\n",
    "3. With an exponential distribution\n",
    "\n",
    "The last approach is in between\n",
    "the two extremes of a uniform distribution\n",
    "and placing all intercepts at 0.\n",
    "It will distribute most intercepts close to 0,\n",
    "but some intercepts will still be at larger values."
   ]
  },
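  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the three options before building the model,\n",
    "the intercepts can be sampled directly.\n",
    "The next cell is a NumPy sketch, not Nengo's own sampling code,\n",
    "and may differ in detail from how `nengo.dists.Exponential` draws samples.\n",
    "It assumes an exponential with scale 0.15,\n",
    "shifted to the threshold and clipped at 1,\n",
    "to mirror the parameters used in the model that follows.\n",
    "The variable names and the 0.2 cutoff are illustrative choices.\n",
    "Under these assumptions, roughly three quarters\n",
    "of the exponential intercepts fall within 0.2 of the threshold,\n",
    "compared to about one fifth for the uniform case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)  # fixed seed so the sketch is repeatable\n",
    "n_samples = 10000\n",
    "thr = 0.0  # threshold; separate name so the notebook's variables stay untouched\n",
    "\n",
    "# 1. Uniform between the threshold and 1.\n",
    "icpt_uniform = rng.uniform(thr, 1.0, size=n_samples)\n",
    "# 2. Every intercept exactly at the threshold.\n",
    "icpt_fixed = np.full(n_samples, thr)\n",
    "# 3. Exponential (scale 0.15), shifted to the threshold and clipped at 1.\n",
    "icpt_exp = np.clip(thr + rng.exponential(0.15, size=n_samples), thr, 1.0)\n",
    "\n",
    "# Fraction of intercepts within 0.2 of the threshold for each option.\n",
    "for name, icpt in [(\"uniform\", icpt_uniform),\n",
    "                   (\"fixed\", icpt_fixed),\n",
    "                   (\"exponential\", icpt_exp)]:\n",
    "    print(name, np.mean(icpt <= thr + 0.2))"
   ]
  },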
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold = 0.\n",
    "\n",
    "with nengo.Network() as model_dists:\n",
    "    stimulus = nengo.Node(stimulus_fn)\n",
    "    ens_uniform = nengo.Ensemble(\n",
    "        n_neurons=n_neurons, dimensions=1,\n",
    "        encoders=nengo.dists.Choice([[1]]),\n",
    "        intercepts=nengo.dists.Uniform(threshold, 1.))\n",
    "    ens_fixed = nengo.Ensemble(\n",
    "        n_neurons=n_neurons, dimensions=1,\n",
    "        encoders=nengo.dists.Choice([[1]]),\n",
    "        intercepts=nengo.dists.Choice([threshold]))\n",
    "    ens_exp = nengo.Ensemble(\n",
    "        n_neurons=n_neurons, dimensions=1,\n",
    "        encoders=nengo.dists.Choice([[1]]),\n",
    "        intercepts=nengo.dists.Exponential(0.15, threshold, 1.))\n",
    "\n",
    "    out_uniform = nengo.Node(size_in=1)\n",
    "    out_fixed = nengo.Node(size_in=1)\n",
    "    out_exp = nengo.Node(size_in=1)\n",
    "\n",
    "    nengo.Connection(stimulus, ens_uniform)\n",
    "    nengo.Connection(stimulus, ens_fixed)\n",
    "    nengo.Connection(stimulus, ens_exp)\n",
    "    nengo.Connection(ens_uniform, out_uniform, function=heaviside)\n",
    "    nengo.Connection(ens_fixed, out_fixed, function=heaviside)\n",
    "    nengo.Connection(ens_exp, out_exp, function=heaviside)\n",
    "\n",
    "    p_uniform = nengo.Probe(out_uniform, synapse=0.005)\n",
    "    p_fixed = nengo.Probe(out_fixed, synapse=0.005)\n",
    "    p_exp = nengo.Probe(out_exp, synapse=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model_dists) as sim_dists:\n",
    "    sim_dists.run(duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us look at the tuning curves first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 4))\n",
    "\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.plot(*nengo.utils.ensemble.tuning_curves(ens_uniform, sim_dists))\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Firing rate [Hz]\")\n",
    "plt.title(\"Uniform intercepts\")\n",
    "\n",
    "plt.subplot(1, 3, 2)\n",
    "plt.plot(*nengo.utils.ensemble.tuning_curves(ens_fixed, sim_dists))\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Firing rate [Hz]\")\n",
    "plt.title(\"Fixed intercepts\")\n",
    "\n",
    "plt.subplot(1, 3, 3)\n",
    "plt.plot(*nengo.utils.ensemble.tuning_curves(ens_exp, sim_dists))\n",
    "plt.xlabel(\"Input\")\n",
    "plt.ylabel(\"Firing rate [Hz]\")\n",
    "plt.title(\"Exponential intercept distribution\")\n",
    "\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us look at how these three ensembles\n",
    "approximate the thresholding function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = sim_dists.trange()\n",
    "plt.figure()\n",
    "plt.plot(t, sim_naive.data[p_naive], label='naive', c='gray')\n",
    "plt.plot(t, sim_dists.data[p_uniform], label='Uniform intercepts')\n",
    "plt.plot(t, sim_dists.data[p_fixed], label='Fixed intercepts')\n",
    "plt.plot(t, sim_dists.data[p_exp], label='Exp. intercept dist.')\n",
    "plt.plot(t, heaviside(stimulus_fn(t)), '--', c='black', label=\"ideal\")\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('Output')\n",
    "plt.legend(loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the fixed intercepts\n",
    "produce slightly higher decoded values close to the threshold,\n",
    "but the slope is lower than for uniform intercepts.\n",
    "The best approximation of the thresholding\n",
    "is done with the exponential intercept distribution.\n",
    "Here we get a quick rise to 1 at the threshold\n",
    "and a fairly constant representation of 1\n",
    "for inputs sufficiently above 0.\n",
    "All three ensembles decode exactly 0 for inputs below 0,\n",
    "because none of their neurons fire there.\n",
    "\n",
    "Nengo provides the `ThresholdingEnsembles` preset\n",
    "to make it easier to assign intercepts\n",
    "according to the exponential distribution.\n",
    "The preset also adjusts the encoders and evaluation points accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Network() as model_final:\n",
    "    stimulus = nengo.Node(stimulus_fn)\n",
    "    with nengo.presets.ThresholdingEnsembles(0.):\n",
    "        ens = nengo.Ensemble(n_neurons=n_neurons, dimensions=1)\n",
    "    output = nengo.Node(size_in=1)\n",
    "\n",
    "    nengo.Connection(stimulus, ens)\n",
    "    nengo.Connection(ens, output, function=heaviside)\n",
    "\n",
    "    p_final = nengo.Probe(output, synapse=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model_final) as sim_final:\n",
    "    sim_final.run(duration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = sim_final.trange()\n",
    "plt.figure()\n",
    "plt.plot(t, sim_final.data[p_final], label=\"final\")\n",
    "plt.plot(t, heaviside(stimulus_fn(t)), '--', c='black', label=\"ideal\")\n",
    "plt.xlabel(\"t\")\n",
    "plt.ylabel(\"Output\")\n",
    "plt.legend(loc='best');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## The takeaway\n",
    "\n",
    "Adjusting ensemble parameters in the right way\n",
    "can sometimes help in implementing functions more accurately in neurons.\n",
    "Think about how your function maps from\n",
    "the input vector space to the output vector space,\n",
    "and consider ways to modify ensemble parameters\n",
    "to better cover important parts\n",
    "of the input vector space."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
